import os
import sys
import copy
import argparse
from time import time
from tqdm import tqdm

import torch

from utils import *

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Where to run: 'cpu', or the index of a CUDA device.
parser.add_argument(
    '--device',
    type=str,
    default='0',
    choices=['cpu', '0', '1', '2', '3'],
)

# Which dataset to process and which node-feature variant to feed the model.
parser.add_argument(
    '--dataset',
    type=str,
    default='cora',
    choices=['cora', 'pubmed', 'ogbn-arxiv', 'ogbn-products'],
)
parser.add_argument(
    '--feature',
    type=str,
    default="gpt_only_embedding",
    choices=['raw', 'content_embedding', 'gpt_response_embedding', 'gpt_only_embedding'],
)

# Model architecture and its hyper-parameters.
parser.add_argument(
    '--model',
    type=str,
    default='SAGE',
    choices=['MLP', 'GCN', 'SAGE'],
)
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0)

args = parser.parse_args()

# Turn the --device choice into a torch device string. Without CUDA the run
# always happens on the CPU; otherwise a device index N becomes 'cuda:N' and
# an explicit 'cpu' is left untouched.
if not torch.cuda.is_available():
    args.device = 'cpu'
elif args.device != 'cpu':
    args.device = 'cuda:' + args.device

def get_split_idx(dataset_folder, dataset, seed=0):
    """Read the train/valid/test node indices stored for one split seed.

    The file read is ``{dataset_folder}{dataset}_{seed}.txt``. The folder is
    concatenated verbatim, so it must already end with a path separator.
    After stripping leading/trailing whitespace the file must contain exactly
    three lines, in the order train, valid, test; each line is a list of
    integer node indices separated by whitespace.

    Args:
        dataset_folder (str): Folder prefix of the split file.
        dataset (str): Dataset name used in the file name.
        seed (int): Split seed used in the file name.

    Returns:
        tuple[list[int], list[int], list[int]]: (train, valid, test) indices.

    Raises:
        ValueError: If the file does not hold exactly three lines, or a token
            is not an integer.
    """
    with open(f'{dataset_folder}{dataset}_{seed}.txt', 'r') as fin:
        lines = fin.read().strip().split('\n')

    # split() with no argument tolerates trailing spaces, repeated spaces and
    # tabs, any of which would otherwise produce an empty token and make
    # int('') raise.
    train, valid, test = ([int(token) for token in line.split()] for line in lines)
    return train, valid, test

def top_k_accuracy(logits, labels, k=1):
    """
    Computes the top-k accuracy for the given logits and labels.

    A sample counts as correct when its label is among the k classes with the
    largest logits for that sample.

    Args:
        logits (torch.Tensor): The logits output from the model of shape [num_samples, num_classes].
        labels (torch.Tensor): The ground truth labels of shape [num_samples].
        k (int): The value of k for top-k accuracy. Values larger than
            num_classes are clamped to num_classes.

    Returns:
        float: The top-k accuracy, or NaN when there are no samples (the
        same value ``eval_func`` yields for an empty split).
    """
    num_samples = labels.size(0)
    if num_samples == 0:
        # Accuracy over zero samples is undefined; avoid ZeroDivisionError.
        return float('nan')

    # topk raises when k exceeds the size of the class dimension.
    k = min(k, logits.size(1))

    # Get the top-k predictions
    _, top_k_predictions = logits.topk(k, dim=1, largest=True, sorted=True)

    # A row is a hit when any of its k predicted classes equals the label
    # (the [num_samples, 1] label column broadcasts across the k columns).
    hits = top_k_predictions.eq(labels.view(-1, 1)).any(dim=1)

    # Compute the top-k accuracy
    return hits.sum().item() / num_samples

@torch.no_grad()
def evaluate(model, data, eval_func):
    """Score a model on the validation and test splits and print top-k stats.

    The model is switched to eval mode and run once on ``data.x`` and
    ``data.edge_index`` with gradient tracking disabled. ``eval_func`` is
    applied to the validation and test rows; top-2 and top-3 accuracies for
    the train, validation and test rows are printed to stdout.

    Args:
        model: Callable module taking ``(data.x, data.edge_index)``.
        data: Object exposing ``x``, ``edge_index``, ``y`` and the
            ``train_mask`` / ``val_mask`` / ``test_mask`` index masks.
        eval_func: Callable ``(logits, labels) -> score``.

    Returns:
        tuple: ``(val_acc, test_acc, logits)`` where ``logits`` is the raw
        model output for the whole input.
    """
    model.eval()
    logits = model(data.x, data.edge_index)

    def rows(mask):
        # Logits and labels restricted to one split.
        return logits[mask], data.y[mask]

    val_acc = eval_func(*rows(data.val_mask))
    test_acc = eval_func(*rows(data.test_mask))

    split_masks = (data.train_mask, data.val_mask, data.test_mask)
    top2_train_acc, top2_val_acc, top2_test_acc = [
        top_k_accuracy(*rows(mask), k=2) for mask in split_masks
    ]
    top3_train_acc, top3_val_acc, top3_test_acc = [
        top_k_accuracy(*rows(mask), k=3) for mask in split_masks
    ]

    print(f"Top-2 Train Acc: {top2_train_acc}, Top-2 Val Acc: {top2_val_acc}, Top-2 Test Acc: {top2_test_acc}")
    print(f"Top-3 Train Acc: {top3_train_acc}, Top-3 Val Acc: {top3_val_acc}, Top-3 Test Acc: {top3_test_acc}")
    print()

    return val_acc, test_acc, logits

# Bind the requested architecture to the common name `GNN`. Only the selected
# model's module is imported.
if args.model == "GCN":
    from GNNs.GCN.model import GCN as GNN
elif args.model == "SAGE":
    from GNNs.SAGE.model import SAGE as GNN
elif args.model == "MLP":
    from GNNs.MLP.model import MLP as GNN
else:
    # sys.exit instead of the builtin exit(): the latter is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    # Passing a string prints it to stderr and exits with status 1.
    sys.exit(f"Model {args.model} is not supported!")

def eval_func(output, labels): # ACC
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        output (torch.Tensor): Scores of shape [num_samples, num_classes].
        labels (torch.Tensor): Ground-truth class indices, one per sample.

    Returns:
        torch.Tensor: Scalar float64 tensor holding the accuracy.
    """
    _, predicted = output.max(dim=1)
    predicted = predicted.type_as(labels)
    num_correct = (predicted == labels).double().sum()
    return num_correct / len(labels)

def adjust_indices_list(sorted_indices_list, ground_truth):
    """Move ``ground_truth`` to the front of a ranked list, in place.

    If the list already starts with ``ground_truth`` it is left alone.
    Otherwise the first occurrence of ``ground_truth`` (if any) is removed
    and the value is inserted at position 0; the relative order of the other
    entries is kept.

    Args:
        sorted_indices_list (list): Non-empty ranked list; modified in place.
        ground_truth: Value that must end up first.

    Returns:
        list: The same list object that was passed in.
    """
    if sorted_indices_list[0] == ground_truth:
        return sorted_indices_list

    if ground_truth in sorted_indices_list:
        sorted_indices_list.pop(sorted_indices_list.index(ground_truth))
    sorted_indices_list.insert(0, ground_truth)
    return sorted_indices_list

# ---------------------------------------------------------------------------
# Main script: load the data and the trained checkpoints, then write, for
# every seed, each node's class ranking with softmax probabilities to text
# files (one "adjusted" file and one "raw" file per seed).
# ---------------------------------------------------------------------------
split_folder = f'../raw_data/{args.dataset}/splits/'
dataset_folder = f'../raw_data/'
output_folder = f"../processed_data"
dataset2num_classes = {'cora': 7, 'pubmed': 3, 'ogbn-arxiv': 40, 'ogbn-products': 47}

# Softmax temperature applied to the logits before ranking.
temperature = 2.0

from raw_data_utils.load import load_data
data, _, label2text = load_data(dataset=args.dataset, dataset_folder=dataset_folder)

num_classes = dataset2num_classes[args.dataset]
num_nodes = data.y.shape[0]

if args.feature != 'raw':
    # Replace the node features with the pre-computed ones.
    # NOTE: torch.load unpickles the file; only load trusted artifacts.
    feature_file = f"../processed_data/{args.dataset}/{args.dataset}_{args.feature}_list.pt"
    new_x = torch.load(feature_file)
    assert new_x.shape[0] == data.x.shape[0]
    # as_tensor accepts a NumPy array (what from_numpy required) and also an
    # object that is already a tensor.
    # NOTE(review): the stored type is not visible here -- confirm.
    data.x = torch.as_tensor(new_x)

data.y = data.y.flatten()
data = data.to(args.device)

print(f"Device: {args.device}")
print(f"Dataset: {args.dataset}")
print(f"Num of nodes: {num_nodes}")
print(f"Num of classes: {num_classes}")
print(f"Feature: {args.feature}")

# For these two datasets the saved splits are loaded, one per seed.
uses_saved_splits = args.dataset in ('cora', 'pubmed')
seeds = list(range(5)) if uses_saved_splits else [0]

# True  -> ground-truth class forced to the front of every node's ranking.
# False -> ranking exactly as produced by the model.
adjusts = [True, False]


def _mask_from_indices(indices, size):
    """Return a boolean tensor of length `size`, True at the given indices."""
    members = set(indices)  # O(1) membership test instead of a list scan per node
    return torch.tensor([node in members for node in range(size)])


def _write_ranked_predictions(path, indices_rows, probs_rows):
    """Write one tab-separated line per row of `label|percentage` entries.

    The j-th index of a row is paired with the j-th probability of the same
    row; the probability is written as a percentage rounded to 2 decimals.
    """
    lines = [
        '\t'.join(
            f'{label2text[class_index]}|{round(prob * 100, 2)}'
            for class_index, prob in zip(indices_row, probs_row)
        ) + '\n'
        for indices_row, probs_row in zip(indices_rows, probs_rows)
    ]
    with open(path, 'w') as fout:
        fout.writelines(lines)


# Make sure the destination exists before the first write.
os.makedirs(f'{output_folder}/{args.dataset}', exist_ok=True)

for seed in seeds:
    print(f"Seed: {seed}")

    if uses_saved_splits:
        # The split depends only on the seed, so build the masks once per seed.
        train_idx, valid_idx, test_idx = get_split_idx(split_folder, args.dataset, seed)
        data.train_mask = _mask_from_indices(train_idx, num_nodes)
        data.val_mask = _mask_from_indices(valid_idx, num_nodes)
        data.test_mask = _mask_from_indices(test_idx, num_nodes)

    ckpt = f"saved_models/{args.dataset}_{seed}/{args.model}_{args.num_layers}_{args.hidden_dim}_{args.feature}.pt"

    for adjust in adjusts:
        # The model is reloaded and re-evaluated for each variant: the
        # adjusted variant mutates the ranking lists in place, and
        # evaluate() prints its statistics once per variant.
        model = GNN(in_channels=data.x.shape[1],
                    hidden_channels=args.hidden_dim,
                    out_channels=num_classes,
                    num_layers=args.num_layers,
                    dropout=args.dropout).to(args.device)

        # map_location lets a checkpoint saved on a GPU load after the
        # script has fallen back to the CPU (or picked another GPU index).
        model.load_state_dict(torch.load(ckpt, map_location=args.device))

        val_acc, test_acc, logits = evaluate(model, data, eval_func)

        probabilities = torch.softmax(logits / temperature, dim=1)
        sorted_probs, sorted_indices = torch.sort(probabilities, descending=True, dim=1)
        sorted_probs_list = sorted_probs.tolist()
        sorted_indices_list = sorted_indices.tolist()

        if adjust:
            gt = data.y.tolist()

            # Count nodes whose top prediction is wrong, before adjusting.
            error = sum(
                1 for ranking, label in zip(sorted_indices_list, gt)
                if ranking[0] != label
            )
            print(f"Error: {error}/{len(gt)}")

            # Only the class order is changed; the probabilities stay in
            # descending order, so the ground-truth class is written next to
            # the highest probability.
            # NOTE(review): confirm this pairing is intended.
            sorted_indices_list = [
                adjust_indices_list(ranking, label)
                for ranking, label in zip(sorted_indices_list, gt)
            ]
            file_tag = 'gnn_output_probability_list'
        else:
            file_tag = 'gnn_output_raw_probability_list'

        _write_ranked_predictions(
            f'{output_folder}/{args.dataset}/{args.dataset}_{seed}_{file_tag}.txt',
            sorted_indices_list,
            sorted_probs_list,
        )

